{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mejJonsu2o7r",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### Taken from [tensorflow-tutorial ](https://www.tensorflow.org/tutorials/text/nmt_with_attention ) and added a *translate_batch()* function to translate a batch and dump outputs into a file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "J0Qjg6vuaHNt"
   },
   "source": [
    "# Neural machine translation with attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tnxXKDjq3jEL"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import unicodedata\n",
    "import re\n",
    "import numpy as np\n",
    "import os\n",
    "import io\n",
    "import time\n",
    "from utils.dataset import NMTDataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "YsrDDMlKA2nd",
    "outputId": "d31e50ca-bb4b-4c79-e920-7d30dac76748"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset.py\n"
     ]
    }
   ],
   "source": [
    "!ls utils\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GOi42V79Ydlr"
   },
   "source": [
    "We'll use the same dataset we worked on notebook-1 (text-processing). For our convenience we've created a utils/dataset.py file which returns train and validation tf.data.Dataset objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "sy9mcFmxBP4J",
    "outputId": "a5f3a663-71c5-4d97-acd2-cc9249042d77"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\n",
      "2646016/2638744 [==============================] - 0s 0us/step\n"
     ]
    }
   ],
   "source": [
    "BUFFER_SIZE = 32000\n",
    "BATCH_SIZE = 64\n",
    "num_examples = 30000\n",
    "\n",
    "dataset_creator = NMTDataset('en-spa')\n",
    "train_dataset, val_dataset, inp_lang, targ_lang = dataset_creator.call(num_examples, BUFFER_SIZE, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "wVtSz89FBk0D",
    "outputId": "48d1ab71-9d23-4369-f081-12c8927814c2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inpute Vocabulary Size: 9414\n",
      "Target Vocabulary Size: 4935\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print(\"Inpute Vocabulary Size: {}\".format(len(inp_lang.word_index)))\n",
    "print(\"Target Vocabulary Size: {}\".format(len(targ_lang.word_index)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "qc6-NK1GtWQt",
    "outputId": "f35ff927-5a3c-480c-95ae-2b0d63844250"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([64, 16]), TensorShape([64, 11]))"
      ]
     },
     "execution_count": 9,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_input_batch, example_target_batch = next(iter(train_dataset))\n",
    "example_input_batch.shape, example_target_batch.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TNfHIF71ulLu"
   },
   "source": [
    "## Write the encoder and decoder model\n",
    "\n",
    "Implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence. The below picture and formulas are an example of attention mechanism from [Luong's paper](https://arxiv.org/abs/1508.04025v5). \n",
    "\n",
    "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\">\n",
    "\n",
    "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*.\n",
    "\n",
    "Here are the equations that are implemented:\n",
    "\n",
    "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\">\n",
    "<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\">\n",
    "\n",
    "This tutorial uses [Bahdanau attention](https://arxiv.org/pdf/1409.0473.pdf) for the encoder. Let's decide on notation before writing the simplified form:\n",
    "\n",
    "* FC = Fully connected (dense) layer\n",
    "* EO = Encoder output\n",
    "* H = hidden state\n",
    "* X = input to the decoder\n",
    "\n",
    "And the pseudo-code:\n",
    "\n",
    "* `score = FC(tanh(FC(EO) + FC(H)))`\n",
    "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n",
    "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n",
    "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n",
    "* `merged vector = concat(embedding output, context vector)`\n",
    "* This merged vector is then given to the GRU\n",
    "\n",
    "The shapes of all the vectors at each step have been specified in the comments in the code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dgVsBjbVDBwX"
   },
   "outputs": [],
   "source": [
    "# Define some useful parameters for further use\n",
    "\n",
    "vocab_inp_size = len(inp_lang.word_index)+1\n",
    "vocab_tar_size = len(targ_lang.word_index)+1\n",
    "max_length_input = example_input_batch.shape[1]\n",
    "max_length_output = example_target_batch.shape[1]\n",
    "\n",
    "embedding_dim = 128\n",
    "units = 1024\n",
    "steps_per_epoch = num_examples//BATCH_SIZE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nZ2rI24i3jFg"
   },
   "outputs": [],
   "source": [
    "# Encoder is composed of embedding layer and then one GRU layer. It produces outputs and last hidden states. \n",
    "# Encoder Outputs shape = (BATCH_SIZE, max_length_input, units)\n",
    "# Last Hidden State Shape = (BATCH_SIZE, units)\n",
    "\n",
    "class Encoder(tf.keras.Model):\n",
    "  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
    "    super(Encoder, self).__init__()\n",
    "    self.batch_sz = batch_sz\n",
    "    self.enc_units = enc_units\n",
    "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "    self.gru = tf.keras.layers.GRU(self.enc_units,\n",
    "                                   return_sequences=True,\n",
    "                                   return_state=True,\n",
    "                                   recurrent_initializer='glorot_uniform')\n",
    "\n",
    "  def call(self, x, hidden):\n",
    "    x = self.embedding(x)\n",
    "    output, state = self.gru(x, initial_state = hidden)\n",
    "    return output, state\n",
    "\n",
    "  def initialize_hidden_state(self):\n",
    "    return tf.zeros((self.batch_sz, self.enc_units))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "60gSVh05Jl6l",
    "outputId": "d5ad818c-8ef1-4f16-e9fb-b04503a5ca8a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder output shape: (batch size, sequence length, units) (64, 16, 1024)\n",
      "Encoder Hidden state shape: (batch size, units) (64, 1024)\n"
     ]
    }
   ],
   "source": [
    "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
    "\n",
    "# sample input\n",
    "sample_hidden = encoder.initialize_hidden_state()\n",
    "sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)\n",
    "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n",
    "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "umohpBN2OM94"
   },
   "outputs": [],
   "source": [
    "class BahdanauAttention(tf.keras.layers.Layer):\n",
    "  def __init__(self, units):\n",
    "    super(BahdanauAttention, self).__init__()\n",
    "    # To recall, score = V*tanh(W1(encoder_outputs) + W2(Prev Step's Hidden State))\n",
    "    self.W1 = tf.keras.layers.Dense(units)\n",
    "    self.W2 = tf.keras.layers.Dense(units)\n",
    "    self.V = tf.keras.layers.Dense(1)\n",
    "\n",
    "  def call(self, query, values):\n",
    "    # query hidden state shape == (batch_size, hidden size)\n",
    "    # query_with_time_axis shape == (batch_size, 1, hidden size)\n",
    "    # values shape == (batch_size, max_len, hidden size)\n",
    "    # we are doing this to broadcast addition along the time axis to calculate the score\n",
    "    query_with_time_axis = tf.expand_dims(query, 1)\n",
    "\n",
    "    # score shape == (batch_size, max_length, 1)\n",
    "    # we get 1 at the last axis because we are applying score to self.V\n",
    "    # the shape of the tensor before applying self.V is (batch_size, max_length, units)\n",
    "    \n",
    "    #-------- COMPUTING EQUATION (4)  ---------#  \n",
    "    score = self.V(tf.nn.tanh(\n",
    "        self.W1(query_with_time_axis) + self.W2(values)))\n",
    "\n",
    "    #-------- COMPUTING EQUATION (1) ------------#\n",
    "    # attention_weights shape == (batch_size, max_length, 1)\n",
    "    attention_weights = tf.nn.softmax(score, axis=1)\n",
    "\n",
    "    #---------- COMPUTING EQUATION (2) -----------#\n",
    "    # context_vector shape after sum == (batch_size, hidden_size)\n",
    "    context_vector = attention_weights * values\n",
    "    # Context vector is passed on to the curren time step's decoder cell\n",
    "    context_vector = tf.reduce_sum(context_vector, axis=1)\n",
    "\n",
    "    return context_vector, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "k534zTHiDjQU",
    "outputId": "6b662027-8c36-4798-f1a0-0f186f8dbc85"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention result shape: (batch size, units) (64, 1024)\n",
      "Attention weights shape: (batch_size, sequence_length, 1) (64, 16, 1)\n"
     ]
    }
   ],
   "source": [
    "attention_layer = BahdanauAttention(10)\n",
    "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n",
    "\n",
    "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n",
    "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yJ_B3mhW3jFk"
   },
   "outputs": [],
   "source": [
    "class Decoder(tf.keras.Model):\n",
    "  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
    "    super(Decoder, self).__init__()\n",
    "    self.batch_sz = batch_sz\n",
    "    self.dec_units = dec_units\n",
    "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "    self.gru = tf.keras.layers.GRU(self.dec_units,\n",
    "                                   return_sequences=True,\n",
    "                                   return_state=True,\n",
    "                                   recurrent_initializer='glorot_uniform')\n",
    "    self.fc = tf.keras.layers.Dense(vocab_size)\n",
    "\n",
    "    # used for attention\n",
    "    self.attention = BahdanauAttention(self.dec_units)\n",
    "\n",
    "  def call(self, x, hidden, enc_output):\n",
    "    # enc_output shape == (batch_size, max_length, hidden_size)\n",
    "    context_vector, attention_weights = self.attention(hidden, enc_output)\n",
    "\n",
    "    # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
    "    x = self.embedding(x)\n",
    "\n",
    "    #------------ COMPUTING EQUATION (3) ----------------------#\n",
    "    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
    "    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
    "\n",
    "    # passing the concatenated vector to the GRU\n",
    "    output, state = self.gru(x)\n",
    "\n",
    "    # output shape == (batch_size * 1, hidden_size)\n",
    "    output = tf.reshape(output, (-1, output.shape[2]))\n",
    "\n",
    "    # output shape == (batch_size, vocab)\n",
    "    x = self.fc(output)\n",
    "\n",
    "    return x, state, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "P5UY8wko3jFp",
    "outputId": "fee4a9eb-3040-47f7-b990-d3808e940516"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decoder output shape: (batch_size, vocab size) (64, 4936)\n"
     ]
    }
   ],
   "source": [
    "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
    "\n",
    "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n",
    "                                      sample_hidden, sample_output)\n",
    "\n",
    "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_ch_71VbIRfK"
   },
   "source": [
    "## Define the optimizer and the loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WmTHr5iV3jFr"
   },
   "outputs": [],
   "source": [
    "optimizer = tf.keras.optimizers.Adam()\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "    from_logits=True, reduction='none')\n",
    "\n",
    "def loss_function(real, pred):\n",
    "  mask = tf.math.logical_not(tf.math.equal(real, 0))\n",
    "  loss_ = loss_object(real, pred)\n",
    "\n",
    "  mask = tf.cast(mask, dtype=loss_.dtype)\n",
    "  loss_ *= mask\n",
    "\n",
    "  return tf.reduce_mean(loss_)"
   ]
  },
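  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the masked loss above looks roughly like this. The notation is introduced here for illustration: $B$ is the batch size, $y_{i,t}$ is the target token id of example $i$ at time step $t$, $z_{i,t}$ are the decoder logits, and id 0 is assumed to be the padding id, as in the mask in `loss_function`:\n",
    "\n",
    "$$m_{i,t} = \\begin{cases} 1 & \\text{if } y_{i,t} \\neq 0 \\\\ 0 & \\text{otherwise} \\end{cases}$$\n",
    "\n",
    "$$L_t = \\frac{1}{B} \\sum_{i=1}^{B} m_{i,t} \\, \\mathrm{CE}(y_{i,t}, z_{i,t})$$\n",
    "\n",
    "Because `tf.reduce_mean` divides by $B$ rather than by the number of unmasked positions, padded positions contribute zero to the sum but still count in the denominator."
   ]
  },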
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DMVWzzsfNl4e"
   },
   "source": [
    "## Checkpoints (Object-based saving)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Zj8bXQTgNwrF"
   },
   "outputs": [],
   "source": [
    "checkpoint_dir = './training_checkpoints'\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
    "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
    "                                 encoder=encoder,\n",
    "                                 decoder=decoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hpObfY22IddU"
   },
   "source": [
    "## Training\n",
    "\n",
    "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n",
    "2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n",
    "3. The decoder returns the *predictions* and the *decoder hidden state*.\n",
    "4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
    "5. Use *teacher forcing* to decide the next input to the decoder.\n",
    "6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n",
    "7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sC9ArXSsVfqn"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(inp, targ, enc_hidden):\n",
    "  loss = 0\n",
    "\n",
    "  with tf.GradientTape() as tape:\n",
    "    enc_output, enc_hidden = encoder(inp, enc_hidden)\n",
    "\n",
    "    dec_hidden = enc_hidden\n",
    "\n",
    "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
    "\n",
    "    # Teacher forcing - feeding the target as the next input\n",
    "    for t in range(1, targ.shape[1]):\n",
    "      # passing enc_output to the decoder\n",
    "      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
    "\n",
    "      loss += loss_function(targ[:, t], predictions)\n",
    "\n",
    "      # using teacher forcing\n",
    "      dec_input = tf.expand_dims(targ[:, t], 1)\n",
    "\n",
    "  batch_loss = (loss / int(targ.shape[1]))\n",
    "\n",
    "  variables = encoder.trainable_variables + decoder.trainable_variables\n",
    "\n",
    "  gradients = tape.gradient(loss, variables)\n",
    "\n",
    "  optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "  return batch_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "ddefjBMa3jF0",
    "outputId": "b9588938-68d5-4ac1-a2cd-d6f07d3d2b57"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Batch 0 Loss 4.6625\n",
      "Epoch 1 Batch 100 Loss 2.3375\n",
      "Epoch 1 Batch 200 Loss 1.9375\n",
      "Epoch 1 Batch 300 Loss 1.8599\n",
      "Epoch 1 Loss 1.6788\n",
      "Time taken for 1 epoch 38.504894733428955 sec\n",
      "\n",
      "Epoch 2 Batch 0 Loss 1.6332\n",
      "Epoch 2 Batch 100 Loss 1.6223\n",
      "Epoch 2 Batch 200 Loss 1.4481\n",
      "Epoch 2 Batch 300 Loss 1.4766\n",
      "Epoch 2 Loss 1.1835\n",
      "Time taken for 1 epoch 27.77025604248047 sec\n",
      "\n",
      "Epoch 3 Batch 0 Loss 1.3868\n",
      "Epoch 3 Batch 100 Loss 1.2430\n",
      "Epoch 3 Batch 200 Loss 1.1446\n",
      "Epoch 3 Batch 300 Loss 1.0991\n",
      "Epoch 3 Loss 0.9428\n",
      "Time taken for 1 epoch 27.430588245391846 sec\n",
      "\n",
      "Epoch 4 Batch 0 Loss 0.8816\n",
      "Epoch 4 Batch 100 Loss 0.8683\n",
      "Epoch 4 Batch 200 Loss 0.9337\n",
      "Epoch 4 Batch 300 Loss 0.9032\n",
      "Epoch 4 Loss 0.7303\n",
      "Time taken for 1 epoch 28.03623127937317 sec\n",
      "\n",
      "Epoch 5 Batch 0 Loss 0.7077\n",
      "Epoch 5 Batch 100 Loss 0.5876\n",
      "Epoch 5 Batch 200 Loss 0.7211\n",
      "Epoch 5 Batch 300 Loss 0.5663\n",
      "Epoch 5 Loss 0.5436\n",
      "Time taken for 1 epoch 27.8458194732666 sec\n",
      "\n",
      "Epoch 6 Batch 0 Loss 0.4867\n",
      "Epoch 6 Batch 100 Loss 0.4494\n",
      "Epoch 6 Batch 200 Loss 0.5099\n",
      "Epoch 6 Batch 300 Loss 0.4105\n",
      "Epoch 6 Loss 0.3901\n",
      "Time taken for 1 epoch 28.244993925094604 sec\n",
      "\n",
      "Epoch 7 Batch 0 Loss 0.2788\n",
      "Epoch 7 Batch 100 Loss 0.3070\n",
      "Epoch 7 Batch 200 Loss 0.3527\n",
      "Epoch 7 Batch 300 Loss 0.3509\n",
      "Epoch 7 Loss 0.2716\n",
      "Time taken for 1 epoch 27.91599726676941 sec\n",
      "\n",
      "Epoch 8 Batch 0 Loss 0.2192\n",
      "Epoch 8 Batch 100 Loss 0.2306\n",
      "Epoch 8 Batch 200 Loss 0.2139\n",
      "Epoch 8 Batch 300 Loss 0.1997\n",
      "Epoch 8 Loss 0.1863\n",
      "Time taken for 1 epoch 28.36590576171875 sec\n",
      "\n",
      "Epoch 9 Batch 0 Loss 0.1307\n",
      "Epoch 9 Batch 100 Loss 0.1158\n",
      "Epoch 9 Batch 200 Loss 0.1888\n",
      "Epoch 9 Batch 300 Loss 0.1652\n",
      "Epoch 9 Loss 0.1286\n",
      "Time taken for 1 epoch 28.02374768257141 sec\n",
      "\n",
      "Epoch 10 Batch 0 Loss 0.0939\n",
      "Epoch 10 Batch 100 Loss 0.0976\n",
      "Epoch 10 Batch 200 Loss 0.1271\n",
      "Epoch 10 Batch 300 Loss 0.1284\n",
      "Epoch 10 Loss 0.0920\n",
      "Time taken for 1 epoch 28.397955179214478 sec\n",
      "\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 10\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  start = time.time()\n",
    "\n",
    "  enc_hidden = encoder.initialize_hidden_state()\n",
    "  total_loss = 0\n",
    "\n",
    "  for (batch, (inp, targ)) in enumerate(train_dataset.take(steps_per_epoch)):\n",
    "    batch_loss = train_step(inp, targ, enc_hidden)\n",
    "    total_loss += batch_loss\n",
    "\n",
    "    if batch % 100 == 0:\n",
    "      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
    "                                                   batch,\n",
    "                                                   batch_loss.numpy()))\n",
    "  # saving (checkpoint) the model every 2 epochs\n",
    "  if (epoch + 1) % 2 == 0:\n",
    "    checkpoint.save(file_prefix = checkpoint_prefix)\n",
    "\n",
    "  print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
    "                                      total_loss / steps_per_epoch))\n",
    "  print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mU3Ce8M6I3rz"
   },
   "source": [
    "## Translate\n",
    "\n",
    "* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
    "* Stop predicting when the model predicts the *end token*.\n",
    "* And store the *attention weights for every time step*.\n",
    "\n",
    "Note: The encoder output is calculated only once for one input."
   ]
  },
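  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decoding rule used by `evaluate` below is greedy: at each step it keeps only the highest-scoring token. As a sketch, with $z_t$ denoting the decoder logits at step $t$ and $w$ ranging over the target vocabulary (notation introduced here for illustration):\n",
    "\n",
    "$$\\hat{y}_t = \\arg\\max_{w} \\; z_t[w]$$\n",
    "\n",
    "The predicted id $\\hat{y}_t$ becomes the decoder input at step $t+1$, and decoding stops at the *end token* or after `max_length_output` steps."
   ]
  },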
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EbQpyYs13jF_"
   },
   "outputs": [],
   "source": [
    "def evaluate(sentence):\n",
    "\n",
    "  sentence = dataset_creator.preprocess_sentence(sentence)\n",
    "\n",
    "  inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n",
    "  inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n",
    "                                                         maxlen=max_length_input,\n",
    "                                                         padding='post')\n",
    "  inputs = tf.convert_to_tensor(inputs)\n",
    "\n",
    "  result = ''\n",
    "\n",
    "  hidden = [tf.zeros((1, units))]\n",
    "  enc_out, enc_hidden = encoder(inputs, hidden)\n",
    "\n",
    "  dec_hidden = enc_hidden\n",
    "  dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)\n",
    "\n",
    "  for t in range(max_length_output):\n",
    "    predictions, dec_hidden, attention_weights = decoder(dec_input,\n",
    "                                                         dec_hidden,\n",
    "                                                         enc_out)\n",
    "\n",
    "    # storing the attention weights to plot later on\n",
    "\n",
    "    predicted_id = tf.argmax(predictions[0]).numpy()\n",
    "\n",
    "    result += targ_lang.index_word[predicted_id] + ' '\n",
    "\n",
    "    if targ_lang.index_word[predicted_id] == '<end>':\n",
    "      return result, sentence\n",
    "\n",
    "    # the predicted ID is fed back into the model\n",
    "    dec_input = tf.expand_dims([predicted_id], 0)\n",
    "\n",
    "  return result, sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sl9zUHzg3jGI"
   },
   "outputs": [],
   "source": [
    "def translate(sentence):\n",
    "  result, sentence = evaluate(sentence)\n",
    "\n",
    "  print('Input: %s' % (sentence))\n",
    "  print('Predicted translation: {}'.format(result))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "n250XbnjOaqP"
   },
   "source": [
    "## Restore the latest checkpoint and test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "UJpT9D5_OgP6",
    "outputId": "d1628fc9-1681-48c1-840e-61677a7e49c3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f282acbe5f8>"
      ]
     },
     "execution_count": 29,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# restoring the latest checkpoint in checkpoint_dir\n",
    "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "WrAM0FDomq3E",
    "outputId": "5104049f-b169-42dd-a26d-63eedcceda2d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: <start> hace mucho frio aqui . <end>\n",
      "Predicted translation: it s very cold here . <end> \n"
     ]
    }
   ],
   "source": [
    "translate(u'hace mucho frio aqui.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "zSx2iM36EZQZ",
    "outputId": "5b8dc076-acb5-4d62-d66f-13f8d0234adb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: <start> esta es mi vida . <end>\n",
      "Predicted translation: this is my life . <end> \n"
     ]
    }
   ],
   "source": [
    "translate(u'esta es mi vida.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "A3LLCx3ZE0Ls",
    "outputId": "8c970bcf-97d8-480d-ed79-cfc80c757939"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: <start> ¿ todavia estan en casa ? <end>\n",
      "Predicted translation: are you home ? <end> \n"
     ]
    }
   ],
   "source": [
    "translate(u'¿todavia estan en casa?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "DUQVLVqUE1YW",
    "outputId": "64f3c639-9bd1-41dc-de13-b7cd7f0e3c03"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: <start> trata de averiguarlo . <end>\n",
      "Predicted translation: try to find out . <end> \n"
     ]
    }
   ],
   "source": [
    "# wrong translation\n",
    "translate(u'trata de averiguarlo.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "92aXUSuvwzOt"
   },
   "outputs": [],
   "source": [
    "def translate_batch(test_dataset):\n",
    "  with open('output_text.txt', 'w') as f:\n",
    "    for (inputs, targets) in test_dataset:\n",
    "      outputs = np.zeros((BATCH_SIZE, max_length_output),dtype=np.int16)\n",
    "      hidden_state = tf.zeros((BATCH_SIZE, units))\n",
    "      enc_output, dec_h = encoder(inputs, hidden_state)\n",
    "      dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
    "      for t in range(max_length_output):\n",
    "        preds, dec_h, _ = decoder(dec_input, dec_h, enc_output)\n",
    "        predicted_id = tf.argmax(preds, axis=1).numpy()\n",
    "        outputs[:, t] = predicted_id\n",
    "        dec_input = tf.expand_dims(predicted_id, 1)\n",
    "      outputs = targ_lang.sequences_to_texts(outputs)\n",
    "      for t, item in enumerate(outputs):\n",
    "        try:\n",
    "          i = item.index('<end>')\n",
    "          f.write(\"%s\\n\" %item[:i])\n",
    "        except: \n",
    "          f.write(\"%s \\n\" % item) # For those translated sequences which didn't correctly translated and have <end> token.\n",
    "\n",
    "outputs = translate_batch(val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "rbHSSD1dT1D0",
    "outputId": "9876c66c-90e1-43f4-ba4f-8f9bebddf894"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "walk slowly . \n",
      "i ll keep my fate . \n",
      "welcome back . \n",
      "i love to travel . \n",
      "don t touch me ! \n",
      "hold the rope . \n",
      "fantastic ! \n",
      "tom betrayed you . \n",
      "get it up . \n",
      "i was very nervous . \n",
      "5952 output_text.txt\n"
     ]
    }
   ],
   "source": [
    "!head output_text.txt\n",
    "! wc -l output_text.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "t2Ci8zsFbL8q",
    "outputId": "a90041ce-b3fa-449c-a410-1ccc08f659ff"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['<start> walk more slowly . <end> <OOV> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> i ll have my revenge . <end> <OOV> <OOV> <OOV>',\n",
       " '<start> welcome back . <end> <OOV> <OOV> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> i love reading . <end> <OOV> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> don t push me ! <end> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> he let go of the rope . <end> <OOV> <OOV>',\n",
       " '<start> run ! <end> <OOV> <OOV> <OOV> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> tom betrayed you . <end> <OOV> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> hold on . <end> <OOV> <OOV> <OOV> <OOV> <OOV> <OOV>',\n",
       " '<start> he was very nervous . <end> <OOV> <OOV> <OOV> <OOV>']"
      ]
     },
     "execution_count": 72,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_targets = list(val_dataset.take(1))\n",
    "val_targets = np.asarray(val_targets[0][1])\n",
    "print(type(val_targets))\n",
    "targ_lang.sequences_to_texts(val_targets)[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kk7sbSx9c1a9"
   },
   "source": [
    "We can see that the model worked well. Despite being not very accurate the translations, however do make some sense. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "okmWWHmMc8T6"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "4_enc_dec_with_BahdanauAttention.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
